import math
import os
from itertools import cycle

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, Normalize, PowerNorm
from matplotlib.animation import FuncAnimation, PillowWriter

def _output_stem(args):
    """Return the file-name stem shared by the PDF and GIF outputs.

    The stem encodes ``args.id`` and the plotting parameters (bounds,
    sample counts, colour range and ``vlevel``) read from *args*.
    """
    return (
        f"{args.id}_heatmap_{args.xmin}_{args.xmax}_{args.xnum}_"
        f"{args.ymin}_{args.ymax}_{args.ynum}_"
        f"{args.vmin}_{args.vmax}_{args.vlevel}"
    )


def heatmap_gif(Z, trajectories, args):
    """Animate trajectories over a heatmap and save the result to disk.

    ``Z`` is drawn with ``imshow`` over the box
    ``[args.xmin, args.xmax] x [args.ymin, args.ymax]`` using a ``PowerNorm``
    colour scale. The start point and all optimal points are marked on top.
    Each trajectory is animated as a moving marker plus the path travelled
    so far.

    The animation is written to ``../gif/<task_name>/<stem>.gif``. When
    ``args.traj_type == "trace_and_dynamics"``, the final frame is also
    written to ``../image/<task_name>/<stem>.pdf``. Missing output
    directories are created. The figure is closed before returning, so
    repeated calls do not accumulate open figures.

    Args:
        Z: values rendered with ``imshow`` (presumably a 2-D scalar field
            sampled on the plotted box -- TODO confirm). Row 0 is drawn at
            ``ymin`` because ``origin='lower'``.
        trajectories: mapping ``name -> array``. Each array is transposed,
            after which row ``t`` is the state at step ``t`` and its first
            two entries are ``(x, y)``. This presumably means each array has
            shape ``(dims, steps)``; verify against the caller.
        args: namespace with the plot bounds (``xmin``/``xmax``/``ymin``/
            ``ymax``), colour settings (``gamma``, ``vmin``, ``vmax``),
            display flags (``color_bar``, ``legend``, ``axis``, ``auto_fit``,
            ``pic_title``) and marker positions (``start_point``,
            ``optimal_points``). It also carries ``traj_type``,
            ``image_per_second``, ``task_name`` and the fields used in the
            output file names.

    Raises:
        ValueError: if ``trajectories`` is empty (nothing to animate).
    """
    if not trajectories:
        raise ValueError("trajectories must contain at least one entry")

    # Per-trajectory styling: colours and line styles are assigned by cycling
    # through the lists below in the iteration order of `trajectories`.
    line_styles = ['-']  # e.g. ['-', '--', '-.', ':'] to vary dash patterns
    colors = [
        "#32CD32",  # Lime green
        "#FFD700",  # Gold
        "#8A2BE2",  # Blue violet
        "#FF1493",  # Deep pink
        "#00CED1",  # Dark turquoise
        "#FF4500",  # Orange red
    ]
    color_cycle = cycle(colors)
    style_cycle = cycle(line_styles)
    color = {name: next(color_cycle) for name in trajectories}
    line_style = {name: next(style_cycle) for name in trajectories}

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        # Heatmap of Z.
        heatmap = ax.imshow(
            Z,
            extent=(args.xmin, args.xmax, args.ymin, args.ymax),
            origin='lower',
            cmap='coolwarm',
            norm=PowerNorm(gamma=args.gamma, vmax=args.vmax, vmin=args.vmin),
            interpolation='bilinear',
        )
        if args.color_bar:
            fig.colorbar(heatmap, ax=ax)

        # Figure margins; tune as needed.
        fig.subplots_adjust(left=0.05, right=0.98, top=1, bottom=0.01)

        # Transpose once so that row t of each array is the state at step t.
        steps = {name: traj.T for name, traj in trajectories.items()}
        n_frames = max(len(s) for s in steps.values())

        # Animated artists: one marker and one trailing path per trajectory.
        points = {
            name: ax.plot([], [], marker='o', markersize=20, label=name,
                          color=color[name], zorder=5)[0]
            for name in trajectories
        }
        paths = {
            name: ax.plot([], [], linestyle=line_style[name],
                          color=color[name], zorder=4)[0]
            for name in trajectories
        }

        # Static markers: the start point and every optimal point.
        start_x, start_y = args.start_point
        ax.scatter(start_x, start_y, color='k', s=400, label='Start Point',
                   zorder=16)
        for x, y in args.optimal_points:
            ax.scatter(x, y, color='k', marker='x', s=400, zorder=16)
        # Optimal points are drawn unlabelled above; this single empty,
        # labelled scatter gives the legend exactly one "Optimal Point" entry.
        ax.scatter([], [], color='k', marker='x', s=400, label='Optimal Point',
                   zorder=16)

        if args.pic_title:
            ax.set_title(args.pic_title)
        if args.legend:
            ax.legend(
                loc='lower left',
                fontsize=24,
                bbox_to_anchor=(0, 0),
                frameon=False,
                handletextpad=0.5,  # gap between each handle and its label
                labelspacing=0.5,   # vertical gap between legend entries
                handlelength=0,
            )
        if args.axis:
            ax.set_xlabel('x')
            ax.set_ylabel('y')
        ax.grid(False)

        if args.auto_fit:
            # Aspect = x-range / y-range makes the displayed data box square.
            ax.set_aspect((args.xmax - args.xmin) / (args.ymax - args.ymin))

        def update(frame):
            """Move every trajectory to step `frame`; return the artists to blit."""
            for name, traj in steps.items():
                # Shorter trajectories stop updating and keep their last state.
                if frame < len(traj):
                    x_val, y_val = traj[frame][:2]
                    # set_data expects sequences, so wrap the scalars.
                    points[name].set_data([x_val], [y_val])
                    paths[name].set_data(traj[:frame + 1, 0],
                                         traj[:frame + 1, 1])
            return list(points.values()) + list(paths.values())

        ani = FuncAnimation(fig, update, frames=n_frames, blit=True)

        stem = _output_stem(args)
        if args.traj_type == "trace_and_dynamics":
            # Render the last frame manually so the PDF shows the full paths.
            update(n_frames - 1)
            pdf_path = f'../image/{args.task_name}/{stem}.pdf'
            os.makedirs(os.path.dirname(pdf_path), exist_ok=True)
            fig.savefig(pdf_path)

        gif_path = f'../gif/{args.task_name}/{stem}.gif'
        os.makedirs(os.path.dirname(gif_path), exist_ok=True)
        ani.save(gif_path, writer=PillowWriter(fps=args.image_per_second))
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)




# Custom piecewise colour normalization (commented out)
# class CustomNormalize(Normalize):
#     def __init__(self, vmin=None, vmax=None, clip=False):
#         super().__init__(vmin, vmax, clip)
        
#     def __call__(self, value, clip=None):
#         if clip is None:
#             clip = self.clip

#         result, is_scalar = self.process_value(value)

#         self.autoscale_None(result)
#         vmin, vmax = self.vmin, self.vmax
#         if vmin > vmax:
#             raise ValueError("minvalue must be less than or equal to maxvalue")
#         elif vmin == vmax:
#             result.fill(0)
#         else:
#             if clip:
#                 mask = np.ma.getmask(result)
#                 result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
#                                      mask=mask)
#             resdat = result.data
#             resdat[(resdat >= 0) & (resdat < 0.45)] = 0.2*np.power(20/9*resdat[(resdat >= 0) & (resdat < 0.45)], 0.2)
#             resdat[(resdat >= 0.45) & (resdat < 0.55)] = 5*(resdat[(resdat >= 0.45) & (resdat < 0.55)]-0.45)+0.2
#             resdat[(resdat >= 0.55) & (resdat < 1)] = 2/9*(resdat[(resdat >= 0.55) & (resdat < 1)]-0.55)+0.7
#             resdat[(resdat >= 1.0) & (resdat <= 10.0)] = 0.2*np.log10(resdat[(resdat >= 1.0) & (resdat <= 10.0)])+0.8
#             result = np.ma.array(resdat, mask=result.mask, copy=False)
#         if is_scalar:
#             result = result[0]
#         return result

#     def inverse(self, value):
#         if not self.scaled():
#             raise ValueError("Not invertible until scaled")

#         result, is_scalar = self.process_value(value)

#         vmin, vmax = self.vmin, self.vmax
#         resdat = result.data
#         resdat[(resdat >= 0) & (resdat < 0.2)] = 0.45*np.power(5*resdat[(resdat >= 0) & (resdat < 0.2)], 5)
#         resdat[(resdat >= 0.2) & (resdat < 0.7)] = 0.2*(resdat[(resdat >= 0.2) & (resdat < 0.7)]-0.2)+0.45
#         resdat[(resdat >= 0.7) & (resdat < 0.8)] = 4.5*(resdat[(resdat >= 0.7) & (resdat < 0.8)]-0.7)+0.55
#         resdat[(resdat >= 0.8) & (resdat <= 1.0)] = np.power(10,5*(resdat[(resdat >= 0.8) & (resdat <= 1.0)]-0.8))
#         result = np.ma.array(resdat, mask=result.mask, copy=False)
#         if is_scalar:
#             result = result[0]
#         return result    

# Using a PyTorch optimizer
# Initialize the parameters
# x = torch.tensor([start_point], requires_grad=True)  # initial point
# # Build the file name and the optimizer dynamically
# if optimizer_name=="sgd":
#     optimizer = torch.optim.SGD([x], lr=lr, momentum=momentum)  # SGD optimizer with momentum
#     filename = f'{args.id}_1d_SGD_lr_{lr}momentum_{momentum}.gif'
# elif optimizer_name=="adam":
#     optimizer = torch.optim.Adam([x], lr=lr, betas=(momentum,beta2))  
#     filename = f'{args.id}_1d_Adam_lr_{lr}momentum_{momentum}beta2_{beta2}.gif'
# elif optimizer_name == 'alto':
#     optimizer = create_adaMA_simple_optimizer([x], lr=lr, betas=(momentum_acc, momentum, beta2), alpha=acceleration, weight_decay=0)
#     filename = f'{args.id}_1d_ALTO_lr_{lr}momentum_{momentum}beta2_{beta2}momentum_acc_{momentum_acc}acceleration_{acceleration}.gif'
 

# Custom colormap, from purple to red
# colors = [
#     "#7F00FF",  # Purple
#     "#4B0082",  # Indigo
#     "#0000FF",  # Blue
#     "#007FFF",  # Light Blue
#     "#00FFFF",  # Cyan
#     "#00FF7F",  # Light Cyan
#     "#4DFF00",  # Light Green
#     "#00FF00",  # Green
#     "#BFFF00",  # Light Yellow
#     "#FFFF00",  # Yellow
#     "#FFAA00",  # Light Orange
#     "#FF7F00",  # Orange
#     "#FF4D00",  # Light Red
#     "#FF0000",   # Red
    # "#FF6347",  # Tomato red
    #     "#4682B4",  # Steel blue
    #     "#32CD32",  # Lime green
    #     "#FFD700",  # Gold
    #     "#8A2BE2",  # Blue violet
    #     "#FF1493",  # Deep pink
    #     "#00CED1",  # Dark turquoise
    #     "#FF4500"   # Orange red
# ]
    

    
# cmap = LinearSegmentedColormap.from_list("purple_to_red", colors)
